import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from sklearn import preprocessing
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_auc_score
import sys,os
sys.path.append(r"/home/yh579/GAFM/GAFM/models")
from bases import FirstNet,SecondNet,torch_auc,totalvaraition,Attacks
# SplitNN
import torch
# Device on which train_GAFM_multiple places its models and training batches.
device='cpu'
# Width of DisNet's two hidden layers.
hidden_dim = 10
# Default weights used by SplitNN_GAFM.backward when mixing the gradient sent to the
# clients: gamma scales the Pen_pertub (label-reconstruction) part and gamma_w the
# GAFM_pertub (discriminator) part.
gamma=gamma_w=1

class DisNet(nn.Module):
    """Scalar critic: three linear layers, each followed by a leaky ReLU.

    Maps a ``(..., 1)`` tensor to a ``(..., 1)`` score through two hidden layers
    of width ``hidden_dim``. The output layer is also passed through a leaky
    ReLU, so the score is not squashed into a bounded range.
    """

    def __init__(self):
        super().__init__()
        # The attribute names L1/L2/L3 define the state_dict keys; keep them.
        self.L1 = nn.Linear(1, hidden_dim)
        self.L2 = nn.Linear(hidden_dim, hidden_dim)
        self.L3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        activation = nn.functional.leaky_relu
        for layer in (self.L1, self.L2, self.L3):
            x = activation(layer(x))
        return x


def addeNoise(sigma, Y):
    """Return ``Y`` plus i.i.d. Gaussian noise ``N(0, sigma^2)`` as a column tensor.

    Args:
        sigma (float): standard deviation of the noise (>= 0).
        Y (array-like): values to perturb, e.g. a NumPy label array of shape
            ``(N,)`` or ``(N, 1)``.

    Returns:
        torch.Tensor: shape ``(Y.size, 1)`` in the default float dtype
        (built with ``torch.Tensor``).
    """
    Y = np.asarray(Y)
    # Draw exactly one noise value per element, in Y's own shape. Drawing a flat
    # (N,) vector and adding it to an (N, 1) array would broadcast to (N, N).
    # The RNG stream consumed is the same N draws either way.
    noisy = Y + np.random.normal(0, sigma, Y.shape)
    return torch.Tensor(noisy).reshape(-1, 1)


from torch.autograd import Variable


def GAFM_pertub(grad_recon, discriminator, server):
    """Gradient of the negated critic score w.r.t. the intermediate tensor.

    Computes ``d/dh sum(-discriminator(server(h)))`` at ``h = grad_recon``. This
    is the adversarial component of the gradient sent back to the clients.

    The gradient is taken with ``torch.autograd.grad`` w.r.t. ``h`` only. As a
    result, the ``.grad`` buffers of the server and discriminator parameters are
    NOT modified as a side effect. A ``backward()`` call here would accumulate
    into them right before the server optimizer step.

    Side effects: puts ``discriminator`` and ``server`` in eval mode (they are
    not switched back). ``server`` is called on a detached copy of ``h``, so a
    wrapper that records its input (e.g. Server_GAFM) records that copy.

    Args:
        grad_recon (torch.Tensor): intermediate tensor ``h``. Despite the name,
            this is an activation, not a gradient.
        discriminator: critic applied to the server output.
        server: module mapping ``h`` to predictions.

    Returns:
        torch.Tensor: detached gradient shaped like ``grad_recon``, with NaN
        entries replaced by 0.
    """
    # Fresh leaf so the gradient stops at h and the caller's graph is untouched.
    h = grad_recon.detach().requires_grad_(True)
    discriminator.eval()
    server.eval()
    score = -discriminator(server(h))
    (grad,) = torch.autograd.grad(score.sum(), h, allow_unused=True)
    if grad is None:  # critic output does not depend on h
        return torch.zeros_like(h)
    grad = grad.detach()
    return torch.where(torch.isnan(grad), torch.zeros_like(grad), grad)


def Pen_pertub(grad_recon, labels, delta, Y_dot=None):
    """Gradient of a binary cross-entropy between ``h`` and soft targets, w.r.t. ``h``.

    The (element-wise, then summed) loss is
    ``-(y * log(h) + (1 - y) * log(1 - h))``. Here ``y = Y_dot`` when given,
    otherwise ``y = |labels - 0.5 + delta|``. ``h`` is expected to lie in
    ``(0, 1)``; NaN gradients (e.g. ``0 * log(0)`` at the boundary) are set to 0.

    Args:
        grad_recon (torch.Tensor): intermediate tensor ``h``. Despite the name,
            this is an activation, not a gradient.
        labels (torch.Tensor): labels used to build the soft targets. Ignored
            when ``Y_dot`` is given.
        delta (float): soft-label offset.
        Y_dot (torch.Tensor, optional): explicit soft targets.

    Returns:
        torch.Tensor: detached gradient shaped like ``grad_recon``.
    """
    # Fresh leaf so only h's gradient is computed and the caller's graph is untouched.
    h = grad_recon.detach().requires_grad_(True)
    target = Y_dot if Y_dot is not None else torch.abs(labels - 0.5 + delta)
    loss = -(target * torch.log(h) + (1 - target) * torch.log(1 - h))
    (grad,) = torch.autograd.grad(loss.sum(), h)
    grad = grad.detach()
    return torch.where(torch.isnan(grad), torch.zeros_like(grad), grad)


# SplitNN
import torch


class Client_GAFM(torch.nn.Module):
    """Client side of a split neural network.

    Args:
        client_model (torch.nn.Module): client-side bottom model.

    Attributes:
        client_model (torch.nn.Module): client-side bottom model.
        client_side_intermidiate (torch.Tensor): last output of ``client_model``,
            still attached to the client's autograd graph.
        grad_from_server (torch.Tensor): last gradient passed to
            ``client_backward``.

    ``train()`` / ``eval()`` are inherited from ``nn.Module``. They accept the
    usual ``mode`` argument, return ``self`` and propagate to the registered
    ``client_model``.
    """

    def __init__(self, client_model):
        super().__init__()
        self.client_model = client_model
        self.client_side_intermidiate = None
        self.grad_from_server = None

    def forward(self, inputs):
        """Run the client model and hand a graph-detached copy to the server.

        Args:
            inputs (torch.Tensor): this client's input features.

        Returns:
            torch.Tensor: the client output, detached from the client graph and
            marked ``requires_grad``. Gradients computed on the server side
            therefore stop at this tensor.
        """
        self.client_side_intermidiate = self.client_model(inputs)
        # Cut the graph here: the server sees a fresh leaf and never backprops
        # into the client's parameters directly.
        return self.client_side_intermidiate.detach().requires_grad_()

    def client_backward(self, grad_from_server):
        """Backpropagate a server-provided gradient through the client model.

        Args:
            grad_from_server (torch.Tensor): gradient w.r.t. the last output of
                ``forward``, with the same shape.
        """
        self.grad_from_server = grad_from_server
        self.client_side_intermidiate.backward(grad_from_server)


class Server_GAFM(torch.nn.Module):
    """Server side of a split neural network.

    Args:
        server_model (torch.nn.Module): server-side top model.

    Attributes:
        server_model (torch.nn.Module): server-side top model.
        intermidiate_to_server (torch.Tensor): last input passed to ``forward``.
        grad_to_client (torch.Tensor): copy of that input's ``.grad``, taken by
            ``server_backward``.

    ``train()`` / ``eval()`` are inherited from ``nn.Module``. They accept the
    usual ``mode`` argument, return ``self`` and propagate to the registered
    ``server_model``.
    """

    def __init__(self, server_model):
        super().__init__()
        self.server_model = server_model
        self.intermidiate_to_server = None
        self.grad_to_client = None

    def forward(self, intermidiate_to_server):
        """Run the server model on the received intermediate tensor.

        Args:
            intermidiate_to_server (torch.Tensor): output sent by the client(s).

        Returns:
            torch.Tensor: output of ``server_model``.
        """
        # Remember the input so server_backward can read its gradient later.
        self.intermidiate_to_server = intermidiate_to_server
        return self.server_model(intermidiate_to_server)

    def server_backward(self):
        """Return (and store) a copy of the gradient of the last ``forward`` input.

        Call this after backpropagating a loss through ``forward``'s output. The
        input must have a populated ``.grad``: either a leaf that requires grad,
        or a non-leaf on which ``retain_grad()`` was called.
        """
        self.grad_to_client = self.intermidiate_to_server.grad.clone()
        return self.grad_to_client


class SplitNN_GAFM(torch.nn.Module):
    """Split neural network with several clients, one server and a critic (GAFM).

    The clients' outputs are averaged into one intermediate tensor ``h``, which
    the server maps to the prediction. On the backward pass the clients do not
    receive the server's true gradient. Instead they get a weighted mix of:

    * the gradient of ``-discriminator(server(h))`` w.r.t. ``h``
      (``GAFM_pertub``), weighted by ``gamma_w``;
    * the gradient of a BCE-style loss of ``h`` against soft labels
      (``Pen_pertub``), weighted by ``gamma``.

    Args:
        clients (list): client wrappers, one per feature slice. Each must be
            callable and expose ``client_backward``, ``train`` and ``eval``.
        server: server wrapper mapping ``h`` to outputs.
        discriminator (torch.nn.Module): critic applied to server outputs.
        client_optimizers (list): one optimizer per client, same order as
            ``clients``.
        server_optimizer: optimizer of the server parameters.
        discriminator_optimizer: optimizer of the discriminator parameters.
        features (list[int]): column boundaries. Client ``i`` reads
            ``inputs[:, features[i]:features[i + 1]]``.
        gamma (float, optional): weight of the ``Pen_pertub`` gradient. ``None``
            (default) uses the module-level ``gamma``.
        gamma_w (float, optional): weight of the ``GAFM_pertub`` gradient.
            ``None`` (default) uses the module-level ``gamma_w``.
    """

    def __init__(self, clients, server, discriminator,
                 client_optimizers, server_optimizer, discriminator_optimizer, features,
                 gamma=None, gamma_w=None):
        super().__init__()
        self.clients = clients
        self.number = len(clients)
        self.server = server
        self.client_optimizers = client_optimizers
        self.server_optimizer = server_optimizer
        self.discriminator = discriminator
        self.discriminator_optimizer = discriminator_optimizer
        self.intermidiate_to_server = 0
        self.intermidiate_to_server_pertub = None
        self.grad_to_client = None
        self.grad_to_client_recon = None
        self.features = features
        # Per-instance gradient weights; None falls back to the module-level values.
        self.gamma = gamma
        self.gamma_w = gamma_w

    def forward(self, inputs, labels, delta, Y_dot):
        """Run every client on its feature slice, average, and apply the server.

        Args:
            inputs (torch.Tensor): batch with all features as columns.
            labels (torch.Tensor): batch labels, stored for ``backward``.
            delta (float): soft-label offset, stored for ``backward``.
            Y_dot (torch.Tensor or None): optional explicit soft targets, stored
                for ``backward``.

        Returns:
            tuple: ``(outputs, h, discriminator, parts)``, where:

            * ``h`` is the mean of the client outputs (its ``.grad`` is retained);
            * ``outputs = server(h)``;
            * ``parts`` lists each client's contribution ``client_output / number``.
        """
        self.labels = labels
        self.Y_dot = Y_dot
        self.delta = delta

        parts = []
        self.intermidiate_to_server = 0
        for i in range(self.number):
            start, stop = self.features[i], self.features[i + 1]
            part = self.clients[i](inputs[:, start:stop]) / self.number
            parts.append(part)
            self.intermidiate_to_server += part
        # h is a non-leaf tensor; keep its .grad after a later loss.backward().
        self.intermidiate_to_server.retain_grad()
        outputs = self.server(self.intermidiate_to_server)
        return outputs, self.intermidiate_to_server, self.discriminator, parts

    def backward(self, standardization):
        """Build the gradient for the clients and backpropagate it through them.

        The gradient is ``gamma_w * g_adv + gamma * g_rec``, with both terms
        evaluated at the last ``h``:

        * ``g_adv`` comes from ``GAFM_pertub``;
        * ``g_rec`` comes from ``Pen_pertub``.

        Each client receives this gradient divided by the number of clients.

        Args:
            standardization (bool): rescale ``g_adv`` and ``g_rec`` to unit L2
                norm before weighting. An all-zero component is left unchanged
                instead of becoming NaN.
        """
        weight_rec = gamma if getattr(self, 'gamma', None) is None else self.gamma
        weight_adv = gamma_w if getattr(self, 'gamma_w', None) is None else self.gamma_w

        g_adv = GAFM_pertub(self.intermidiate_to_server, self.discriminator, self.server)
        g_rec = Pen_pertub(self.intermidiate_to_server, self.labels, self.delta, self.Y_dot)
        if standardization:
            g_adv = self._unit_l2(g_adv)
            g_rec = self._unit_l2(g_rec)

        self.intermidiate_to_server_pertub = weight_adv * g_adv
        self.grad_to_client_recon = weight_rec * g_rec
        self.grad_to_client = self.grad_to_client_recon + self.intermidiate_to_server_pertub
        for i in range(self.number):
            self.clients[i].client_backward(self.grad_to_client / self.number)

    @staticmethod
    def _unit_l2(t):
        """Scale ``t`` to unit L2 norm; a zero-norm tensor is returned unchanged (no 0/0)."""
        norm = t.pow(2).sum().sqrt()
        return t / norm if norm > 0 else t

    def zero_grads(self):
        """Zero the gradients held by every client, server and critic optimizer."""
        for i in range(self.number):
            self.client_optimizers[i].zero_grad()
        self.server_optimizer.zero_grad()
        self.discriminator_optimizer.zero_grad()

    def step(self):
        """Step the client and server optimizers; the critic is stepped by the caller."""
        for i in range(self.number):
            self.client_optimizers[i].step()
        self.server_optimizer.step()

    def train(self, mode=True):
        """Set training mode on this module, every client, the server and the critic.

        Clients live in a plain list (not registered submodules), so each part
        is switched explicitly. Returns ``self`` like ``nn.Module.train``.
        """
        self.training = mode
        for part in (*self.clients, self.server, self.discriminator):
            if mode:
                part.train()
            else:
                part.eval()
        return self

    def eval(self):
        """Switch every part to evaluation mode; returns ``self``."""
        return self.train(False)


def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random``; make cuDNN deterministic."""
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


def _symmetric_auc(labels, scores):
    """Return ``max(AUC, 1 - AUC)`` of ``scores`` (reshaped to a column) against ``labels``.

    A leakage attack is equally informative whichever class it ranks higher, so
    the ``torch_auc`` value is folded to be at least one half.
    """
    auc = torch_auc(labels, scores.view(-1, 1))
    return max(auc, 1 - auc)


def _build_splitnn_gafm(features, lr):
    """Build the three-client SplitNN_GAFM (double precision, on ``device``) with Adam optimizers.

    Client ``k`` (0-based) gets a FirstNet over columns
    ``features[k]:features[k + 1]``. Models are constructed in the order
    client 1, 2, 3, server, discriminator, which fixes how the torch RNG is
    consumed.
    """
    model_clients = []
    for k in range(3):
        model = FirstNet(input_dim=features[k + 1] - features[k]).to(device)
        model.double()
        model_clients.append(model)
    clients = [Client_GAFM(m) for m in model_clients]
    opt_c = [optim.Adam(m.parameters(), lr=lr) for m in model_clients]

    model_2 = SecondNet().to(device)
    model_2.double()
    opt_2 = optim.Adam(model_2.parameters(), lr=lr)
    server = Server_GAFM(model_2)

    discriminator = DisNet().to(device)
    discriminator.double()
    opt_D = optim.Adam(discriminator.parameters(), lr=lr)

    return SplitNN_GAFM(clients, server, discriminator, opt_c, opt_2, opt_D, features)


def train_GAFM_multiple(Epochs, features, train_loader, test_loader, gamma=1, gamma_w=1,
                        sigma=0.01, delta=0.1, lr=1e-3, info=False, standardization=True,
                        regenerate=False, mode='norandom'):
    """Train a three-client GAFM split network and report utility and leakage metrics.

    Each training step does the following:

    1. Update the critic with a WGAN-style loss. Noisy labels
       (``labels + N(0, sigma^2)``) are the "real" samples and server outputs
       are the "fake" ones. Then clip the critic weights to ``[-0.1, 0.1]``.
    2. Update the server on ``-gamma_w * mean(critic(outputs))``.
    3. Update the clients with the gradient mix built by
       ``SplitNN_GAFM.backward``.

    After every step one test batch, ``next(iter(test_loader))``, is scored for
    the test AUC.

    Args:
        Epochs (int): number of passes over ``train_loader``; must be >= 1.
        features (sequence of int): column boundaries. Client ``k`` (0-based)
            reads ``inputs[:, features[k]:features[k + 1]]``. Needs at least 4
            entries.
        train_loader: iterable of ``(inputs, labels)`` batches.
        test_loader: loader whose batches start with ``(inputs, labels)``.
        gamma (float): weight of the BCE reconstruction loss, whose soft targets
            are ``|labels - 0.5 + delta|``.
        gamma_w (float): weight of the critic / adversarial losses.
        sigma (float): standard deviation of the label noise fed to the critic.
        delta (float): soft-label offset. In ``mode='random_fix'`` it is also the
            upper bound of the random offset.
        lr (float): Adam learning rate for every model.
        info (bool): unused.
        standardization (bool): forwarded to ``SplitNN_GAFM.backward``.
        regenerate (bool): reseed with the epoch index each epoch instead of 0.
        mode (str): soft targets used for the client gradient:

            * ``'norandom'``: ``|labels - 0.5 + delta|``;
            * ``'random_fix'``: the offset is drawn uniformly from
              ``[0, delta)`` per element;
            * anything else: a Bernoulli sample of ``|labels - 0.5 + delta|``.

    Returns:
        tuple: ``(train_auc, test_auc, train_tvd, na, ma, cos, na1, ma1, cos1,
        na2, ma2, cos2, na3, ma3, cos3, model)`` from the last epoch.
        ``na``/``ma``/``cos`` are the three leak AUCs derived from ``Attacks``
        (presumably norm-, mean- and cosine-based attacks), folded to
        ``max(auc, 1 - auc)``. Suffix ``k`` marks client ``k``. ``model`` is the
        trained SplitNN_GAFM.

    Raises:
        ValueError: if ``Epochs < 1`` (there would be no metrics to return).
    """
    if Epochs < 1:
        raise ValueError("Epochs must be at least 1")
    clip = True
    clip_value = 0.1
    n_views = 4  # view 0 = aggregate gradient, views 1..3 = client k
    BCE = nn.BCELoss()

    splitnn_GAFM = _build_splitnn_gafm(features, lr)
    # Keep the caller's loss weights on the model instance so the client-gradient
    # mixing can use them instead of the module-level gamma / gamma_w defaults.
    splitnn_GAFM.gamma = gamma
    splitnn_GAFM.gamma_w = gamma_w
    splitnn_GAFM.train()

    for epoch in range(Epochs):
        setup_seed(epoch if regenerate else 0)

        epoch_loss = 0
        epoch_outputs, epoch_labels = [], []
        epoch_outputs_test, epoch_labels_test = [], []
        epoch_g = []
        epoch_g_norm = [[] for _ in range(n_views)]
        epoch_g_mean = [[] for _ in range(n_views)]
        epoch_g_inner = [[] for _ in range(n_views)]

        for inputs, labels in train_loader:
            splitnn_GAFM.zero_grads()
            inputs = inputs.to(device).double()
            labels = labels.to(device).double()
            # Noisy labels are the "real" samples for the critic.
            Y_ = addeNoise(sigma, labels.detach().cpu().numpy()).double().to(device)

            if mode == 'norandom':
                Y_dot = None
            elif mode == 'random_fix':
                deltas = torch.rand(labels.shape).to(labels.device) * delta
                Y_dot = torch.abs(labels - 0.5 + deltas)
            else:
                Y_dot = torch.bernoulli(torch.abs(labels - 0.5 + delta))

            outputs, intermidiate_to_server, _, _ = splitnn_GAFM(inputs, labels, delta, Y_dot)

            # 1) Critic update (WGAN-style loss with weight clipping).
            critic = splitnn_GAFM.discriminator
            loss_D = gamma_w * (-torch.mean(critic(Y_)) + torch.mean(critic(outputs.detach())))
            loss_D.backward()
            splitnn_GAFM.discriminator_optimizer.step()
            if clip:
                for p in critic.parameters():
                    p.data.clamp_(-clip_value, clip_value)

            # 2) Server update from the adversarial loss.
            # 3) Client update from the mixed gradient built in SplitNN_GAFM.backward.
            # NOTE: the client outputs reaching the server are detached, so
            # loss_recon's gradient does not update any client parameters; it
            # mainly contributes to the reported loss.
            splitnn_GAFM.server_optimizer.zero_grad()
            critic.eval()
            loss_D_1 = -gamma_w * torch.mean(critic(outputs))
            loss_recon = gamma * BCE(intermidiate_to_server, torch.abs(labels - 0.5 + delta))
            loss = loss_recon + loss_D_1
            loss.backward()
            splitnn_GAFM.backward(standardization)
            splitnn_GAFM.step()

            epoch_loss += loss.item() / len(train_loader.dataset)
            # Detach so the epoch's stored outputs don't keep autograd graphs alive.
            epoch_outputs.append(outputs.detach())
            epoch_labels.append(labels)

            grad_to_client = splitnn_GAFM.grad_to_client
            epoch_g.append(grad_to_client)
            # Every client receives this same gradient (divided by the number of
            # clients), so each view is attacked on grad_to_client itself.
            for view in range(n_views):
                g_norm, g_mean, g_inner = Attacks(grad_to_client, labels)
                epoch_g_norm[view].append(g_norm)
                epoch_g_mean[view].append(g_mean)
                epoch_g_inner[view].append(g_inner)

            test_batch = next(iter(test_loader))
            labels_test = test_batch[1]
            outputs_test, _, _, _ = splitnn_GAFM(test_batch[0].to(device).double(), labels_test,
                                                 delta, Y_dot)
            epoch_outputs_test.append(outputs_test.detach())
            epoch_labels_test.append(labels_test)

        all_labels = torch.cat(epoch_labels)
        train_auc = torch_auc(all_labels, torch.cat(epoch_outputs))
        test_auc = torch_auc(torch.cat(epoch_labels_test), torch.cat(epoch_outputs_test))
        train_tvd = totalvaraition(all_labels, torch.cat(epoch_g))
        # leak[v] = (NA, MA, Cos) leak AUC for view v.
        leak = [
            (_symmetric_auc(all_labels, torch.cat(epoch_g_norm[v])),
             _symmetric_auc(all_labels, torch.cat(epoch_g_mean[v])),
             _symmetric_auc(all_labels, torch.cat(epoch_g_inner[v])))
            for v in range(n_views)
        ]

        if epoch % 10 == 0 or epoch == Epochs - 1:
            print('Epoch', epoch, 'Training Loss', epoch_loss,
                  'Training AUC', train_auc,
                  'Testing AUC', test_auc,
                  'TVD', train_tvd,
                  'NA Leak AUC', leak[0][0],
                  'MA Leak AUC', leak[0][1],
                  'Cos Leak AUC', leak[0][2]
                  )
            for v in range(1, n_views):
                print('Client%d' % v,
                      'NA Leak AUC', leak[v][0],
                      'MA Leak AUC', leak[v][1],
                      'Cos Leak AUC', leak[v][2]
                      )

    return (train_auc, test_auc, train_tvd,
            *leak[0], *leak[1], *leak[2], *leak[3],
            splitnn_GAFM)


